# Python dependencies in pip requirements-file format
# (typically installed with `pip install -r <this file>`).

# `-f` is pip's --find-links option: an additional page pip searches for
# package archives. The URL points at the JAX CUDA releases index.
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# JAX pinned to an exact version, requesting the package's "cuda" extra.
jax[cuda]==0.3.2

# Lower-bounded only: any release at or above the stated version is accepted.
absl-py>=0.12.0
numpy>=1.21.5
tensorflow-cpu==2.6.0  # Using tensorflow-cpu to have all GPU memory for JAX.
tensorflow-datasets>=4.4.0
matplotlib>=3.5.0

# Pinned to exact versions.
clu==0.0.3
flax==0.3.5
chex==0.0.7
optax==0.1.0
ml-collections==0.1.0

# No version constraint: pip will pick whatever release it resolves at
# install time. NOTE(review): consider pinning these for reproducible installs.
scikit-image
# The package name suggests a build targeting TensorFlow 2.6.0, matching the
# tensorflow-cpu pin above — TODO confirm. NOTE(review): verify this package
# does not also pull in the full `tensorflow` package alongside tensorflow-cpu.
waymo-open-dataset-tf-2-6-0